In [33]:
import sys
sys.setrecursionlimit(10000)
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)

import os
os.environ['GNUMPY_IMPLICIT_CONVERSION'] = 'ignore'
print os.environ.get('GNUMPY_IMPLICIT_CONVERSION')
ignore

In [2]:
import cPickle
import gzip

from breze.learn.data import one_hot
from breze.learn.base import cast_array_to_local_type
from breze.learn.utils import tile_raster_images

import climin.stops
import climin.initialize
from climin import mathadapt as ma

from breze.learn import hvi
from breze.learn.hvi import HmcViModel
from breze.learn.hvi.energies import (NormalGaussKinEnergyMixin, DiagGaussKinEnergyMixin)
from breze.learn.hvi.inversemodels import MlpGaussInvModelMixin

from matplotlib import pyplot as plt
from matplotlib import cm

import numpy as np

#import fasttsne

from IPython.html import widgets
%matplotlib inline

import theano
theano.config.compute_test_value = 'ignore'#'raise'
from theano import (tensor as T, clone)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
C:\Anaconda\lib\site-packages\scipy\lib\_util.py:35: DeprecationWarning: Module scipy.linalg.blas.fblas is deprecated, use scipy.linalg.blas instead
  DeprecationWarning)
Using gpu device 0: Quadro K2200 (CNMeM is disabled)

In [3]:
datafile = '../mnist.pkl.gz'
# Load data.                                                                                                   

with gzip.open(datafile,'rb') as f:                                                                        
    train_set, val_set, test_set = cPickle.load(f)                                                       

X, Z = train_set                                                                                               
VX, VZ = val_set
TX, TZ = test_set

Z = one_hot(Z, 10)
VZ = one_hot(VZ, 10)
TZ = one_hot(TZ, 10)

X_no_bin = X
VX_no_bin = VX
TX_no_bin = TX

# binarize the MNIST data
np.random.seed(0)
X  = np.random.binomial(1, X) * 1.0
VX = np.random.binomial(1, VX) * 1.0
TX = np.random.binomial(1, TX) * 1.0

image_dims = 28, 28

X_np, Z_np, VX_np, VZ_np, TX_np, TZ_np, X_no_bin_np, VX_no_bin_np, TX_no_bin_np = X, Z, VX, VZ, TX, TZ, X_no_bin, VX_no_bin, TX_no_bin
X, Z, VX, VZ, TX, TZ, X_no_bin, VX_no_bin, TX_no_bin = [cast_array_to_local_type(i) 
                                                        for i in (X, Z, VX,VZ, TX, TZ, X_no_bin, VX_no_bin, TX_no_bin)]
print X.shape
(50000L, 784L)

\\srv-file.brml.tum.de\nthome\cwolf\code\2015-christopherwolf-msc\breze\learn\base.py:39: UserWarning: Implicilty converting numpy.ndarray to gnumpy.garray
  warnings.warn('Implicilty converting numpy.ndarray to gnumpy.garray')

In [4]:
fig, ax = plt.subplots(figsize=(9, 9))

img = tile_raster_images(X_np[:64], image_dims, (8, 8), (1, 1))
ax.imshow(img, cmap=cm.binary)
Out[4]:
<matplotlib.image.AxesImage at 0x24576320>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [80]:
fast_dropout = False

if fast_dropout:
    class MyHmcViModel(HmcViModel, 
                   hvi.FastDropoutMlpBernoulliVisibleVAEMixin, 
                   hvi.FastDropoutMlpGaussLatentVAEMixin, 
                   DiagGaussKinEnergyMixin,
                   MlpGaussInvModelMixin):
        pass

    kwargs = {
        'p_dropout_inpt': .1,
        'p_dropout_hiddens': [.2, .2],
    }

    print 'yeah'

else:
    class MyHmcViModel(HmcViModel, 
                   hvi.MlpBernoulliVisibleVAEMixin, 
                   hvi.MlpGaussLatentVAEMixin, 
                   DiagGaussKinEnergyMixin,
                   MlpGaussInvModelMixin):
        pass
    kwargs = {}


batch_size = 500
#optimizer = 'rmsprop', {'step_rate': 1e-4, 'momentum': 0.95, 'decay': .95, 'offset': 1e-6}
#optimizer = 'adam', {'step_rate': .5, 'momentum': 0.9, 'decay': .95, 'offset': 1e-6}
optimizer = 'adam', {'step_rate': 0.0001}

# This is the number of random variables NOT the size of 
# the sufficient statistics for the random variables.
n_latents = 2
n_hidden = 200

m = MyHmcViModel(X.shape[1], n_latents, 
                 [n_hidden, n_hidden], ['rectifier'] * 2, 
                 [n_hidden, n_hidden], ['rectifier'] * 2, 
                 [n_hidden], ['rectifier'] * 1,
                 n_hmc_steps=3, n_lf_steps=4,
                 n_z_samples=1,
          optimizer=optimizer, batch_size=batch_size, allow_partial_velocity_update=False, perform_acceptance_step=False,
          compute_transition_densities=True,
          **kwargs)

#climin.initialize.randomize_normal(m.parameters.data, 0.1, 1e-1)
#m.parameters.__setitem__(m.hmc_sampler.step_size_param, 0.2)
#m.parameters.__setitem__(m.kin_energy.mlp.layers[-1].bias, 1)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [81]:
old_best_params = None
#print m.score(TX)
print m.parameters.data.shape
(554795,)

In [82]:
FILENAME = 'hvi_gen2_recog2_aux1_late2_hid200_pretrained.pkl'

# In[5]:
#old_best_params = None
f = open(FILENAME, 'rb')
np_array = cPickle.load(f)
old_best_params = cast_array_to_local_type(np_array)
f.close()
print old_best_params.shape
(554795,)

In [83]:
m.parameters.data = old_best_params.copy()
#old_best_loss = m.score(VX)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [84]:
print m.score(VX)
print m.score(TX)
compiled score function
garray(129.34292602539062)
garray(130.10096740722656)

In [56]:
print m.parameters.view(m.init_recog.mlp.layers[2].bias)
garray([-0.38601121, -0.04792288, -9.10118866, -8.94394207])

In [174]:
#m.parameters.__setitem__(m.hmc_sampler.step_size_param, 0.2)
#m.parameters.__setitem__(m.init_recog.mlp.layers[2].bias, cast_array_to_local_type(np.array([-0.34611687, -0.09502647, -1.77520561, -2.25207138])))
#m.parameters.__setitem__(m.kin_energy.variance_parameter, cast_array_to_local_type(np.array([-0.7, -0.7])))
In [57]:
print 0.1 * m.parameters.view(m.hmc_sampler.step_size_param) ** 2 + 1e-8
garray([ 0.00400001])

In [87]:
#print m.estimate_nll(TX, 500)
142.505125

In [58]:
print m.score(VX_no_bin)
print m.score(TX_no_bin)
garray(129.36375427246094)
garray(130.09336853027344)

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [85]:
TARGET_FILENAME = 'hvi_gen2_recog2_aux1_late2_hid200_pretr_3hmc_04lf'
FILETYPE_EXTENSION = '.pkl'
old_best_params = None

max_passes = 500
max_iter = max_passes * X.shape[0] / batch_size
n_report = X.shape[0] / batch_size

stop = climin.stops.AfterNIterations(max_iter)
pause = climin.stops.ModuloNIterations(n_report)

# theano.config.optimizer = 'fast_compile'

for i, info in enumerate(m.powerfit((X_no_bin,), (VX,), stop, pause, eval_train_loss=False)):
    print i, info['loss'], info['val_loss'], np.exp(m.parameters.view(m.kin_energy.variance_parameter).as_numpy_array()), 0.1*m.parameters.view(m.hmc_sampler.step_size_param).as_numpy_array()**2 + 1e-8
    if i == 0 and old_best_params is not None:
        if info['best_loss'] > old_best_loss:
            info['best_loss'] = old_best_loss
            info['best_pars'] = old_best_params
    
    if info['best_loss'] == info['val_loss']:
        f = open(TARGET_FILENAME + FILETYPE_EXTENSION, 'wb')
        cPickle.dump(m.parameters.data, f, protocol=cPickle.HIGHEST_PROTOCOL)
        f.close()
Compiled loss functions
0 0.0 129.171478271 [[ 0.49994591  0.50026412]] [ 0.00369003]
1 0.0 129.132247925 [[ 0.50255068  0.50236616]] [ 0.00348206]
2 0.0 129.130050659 [[ 0.50402062  0.50451102]] [ 0.00328921]
3 0.0 129.070907593 [[ 0.50470796  0.5061452 ]] [ 0.00314143]
4 0.0 129.072753906 [[ 0.50430868  0.50724913]] [ 0.00304434]
5 0.0 128.989273071 [[ 0.50306874  0.50797879]] [ 0.00298369]
6 0.0 129.089767456 [[ 0.50136266  0.50890538]] [ 0.00292767]
7 0.0 128.982925415 [[ 0.4992255   0.50948319]] [ 0.00290191]
8 0.0 129.019470215 [[ 0.49695905  0.51027062]] [ 0.0028739]
9 0.0 128.961639404 [[ 0.49481898  0.51093198]] [ 0.00284737]
10 0.0 128.922821045 [[ 0.49264585  0.51152844]] [ 0.00282793]
11 0.0 129.010528564 [[ 0.48998067  0.51208499]] [ 0.00284389]
12 0.0 129.058380127 [[ 0.48742633  0.5129238 ]] [ 0.00283086]
13 0.0 128.965408325 [[ 0.48489876  0.5135174 ]] [ 0.00284226]
14 0.0 128.994338989 [[ 0.48226989  0.51424185]] [ 0.00285277]
15 0.0 128.896347046 [[ 0.47977321  0.51494114]] [ 0.0028652]
16 0.0 128.909088135 [[ 0.47751956  0.51584123]] [ 0.00284899]
17 0.0 129.044891357 [[ 0.47527793  0.51645185]] [ 0.00286824]
18 0.0 128.99432373 [[ 0.472854    0.51777064]] [ 0.00284476]
19 0.0 128.993545532 [[ 0.47059548  0.51848833]] [ 0.00286182]
20 0.0 128.862213135 [[ 0.46837092  0.51896401]] [ 0.00289598]
21 0.0 129.02432251 [[ 0.46619822  0.5198922 ]] [ 0.00290645]
22 0.0 129.016464233 [[ 0.46423791  0.52048298]] [ 0.00291552]
23 0.0 128.985565186 [[ 0.46215724  0.52186488]] [ 0.00288466]
24 0.0 129.031204224 [[ 0.46034009  0.52249244]] [ 0.00290199]
25 0.0 129.001968384 [[ 0.4586508   0.52370584]] [ 0.00287562]
26 0.0 129.009216309 [[ 0.45701766  0.52410323]] [ 0.00288914]
27 0.0 129.037353516 [[ 0.45509937  0.52492268]] [ 0.00290971]
28 0.0 128.938903809 [[ 0.45319753  0.52626379]] [ 0.0028939]
29 0.0 129.03237915 [[ 0.45112399  0.52772114]] [ 0.00288354]
30 0.0 128.970291138 [[ 0.44953542  0.52840695]] [ 0.0029059]
31 0.0 129.049453735 [[ 0.44789837  0.52932266]] [ 0.00291861]
32 0.0 128.926589966 [[ 0.44645246  0.53013115]] [ 0.00291648]
33 0.0 128.906265259 [[ 0.44485645  0.53136992]] [ 0.00289839]
34 0.0 128.910064697 [[ 0.44343064  0.5323511 ]] [ 0.0028977]
35 0.0 128.863922119 [[ 0.44145685  0.53356555]] [ 0.00289805]
36 0.0 128.937423706 [[ 0.44020705  0.53420785]] [ 0.00291265]
37 0.0 128.941589355 [[ 0.43862236  0.53508988]] [ 0.00292088]
38 0.0 129.093139648 [[ 0.43697971  0.53640632]] [ 0.00289987]
39 0.0 128.984329224 [[ 0.4353661  0.5371824]] [ 0.00291807]
40 0.0 129.068145752 [[ 0.4341404   0.53837227]] [ 0.00289828]
41 0.0 129.058837891 [[ 0.43277323  0.53895338]] [ 0.00292209]
42 0.0 129.034927368 [[ 0.43170644  0.54012824]] [ 0.00289131]
43 0.0 128.983840942 [[ 0.43060009  0.54120067]] [ 0.00287128]
44 0.0 129.072769165 [[ 0.4291563   0.54232684]] [ 0.00288007]
45 0.0 128.983276367 [[ 0.4276082   0.54312854]] [ 0.00290635]
46 0.0 128.911224365 [[ 0.42664698  0.54379532]] [ 0.00290718]
47 0.0 128.969726562 [[ 0.42538566  0.54519904]] [ 0.00288162]
48 0.0 128.922332764 [[ 0.42433832  0.54653807]] [ 0.00285148]
49 0.0 128.859405518 [[ 0.42299272  0.5468813 ]] [ 0.00289478]
50 0.0 128.935562134 [[ 0.42150572  0.54794724]] [ 0.00290698]
51 0.0 128.954528809 [[ 0.42071759  0.54863148]] [ 0.00290379]
52 0.0 128.98399353 [[ 0.42003443  0.54960369]] [ 0.00287627]
53 0.0 128.936538696 [[ 0.4181618   0.55095232]] [ 0.00291159]
54 0.0 128.91456604 [[ 0.41711094  0.55188308]] [ 0.00291542]
55 0.0 128.905517578 [[ 0.41614015  0.55317825]] [ 0.00287925]
56 0.0 128.929916382 [[ 0.41510723  0.55421154]] [ 0.00287184]
57 0.0 128.865936279 [[ 0.41402116  0.55400512]] [ 0.00294681]
58 0.0 128.831451416 [[ 0.41309144  0.55527545]] [ 0.00291194]
59 0.0 128.91027832 [[ 0.41195995  0.55624958]] [ 0.00291853]
60 0.0 128.874481201 [[ 0.41124973  0.55766355]] [ 0.0028672]
61 0.0 128.891143799 [[ 0.41031335  0.5581972 ]] [ 0.00288576]
62 0.0 128.735153198 [[ 0.40965536  0.55882967]] [ 0.0028829]
63 0.0 128.918395996 [[ 0.4087496   0.55936364]] [ 0.00290588]
64 0.0 128.868667603 [[ 0.40778116  0.56068061]] [ 0.00287706]
65 0.0 128.842727661 [[ 0.4065388   0.56157769]] [ 0.00289166]
66 0.0 128.781890869 [[ 0.40594773  0.56286539]] [ 0.00285824]
67 0.0 128.800628662 [[ 0.40473096  0.56402393]] [ 0.00286516]
68 0.0 128.817672729 [[ 0.40415353  0.56458234]] [ 0.00285939]
69 0.0 128.83682251 [[ 0.40326347  0.56522852]] [ 0.0028721]
70 0.0 128.727722168 [[ 0.40236038  0.56585827]] [ 0.00288307]
71 0.0 128.789154053 [[ 0.40127185  0.56689382]] [ 0.00289941]
72 0.0 128.757827759 [[ 0.40012802  0.56768951]] [ 0.00291337]
73 0.0 128.907974243 [[ 0.39941743  0.56932758]] [ 0.00286702]
74 0.0 128.895935059 [[ 0.39864127  0.57036999]] [ 0.00285764]
75 0.0 128.835876465 [[ 0.39761524  0.570594  ]] [ 0.00290885]
76 0.0 128.870803833 [[ 0.39725424  0.57157426]] [ 0.00287289]
77 0.0 129.031326294 [[ 0.39668221  0.5724306 ]] [ 0.00286407]
78 0.0 128.953948975 [[ 0.39589737  0.57319802]] [ 0.00286457]
79 0.0 128.902191162 [[ 0.39552654  0.57340243]] [ 0.00287599]
80 0.0 128.86680603 [[ 0.39499876  0.57325289]] [ 0.0029058]
81 0.0 128.869873047 [[ 0.39441486  0.57399076]] [ 0.00289678]
82 0.0 128.932815552 [[ 0.39401821  0.5750865 ]] [ 0.00285897]
83 0.0 128.890899658 [[ 0.39305534  0.57561181]] [ 0.00289576]
84 0.0 128.868301392 [[ 0.39270417  0.57672508]] [ 0.00284703]
85 0.0 128.772201538 [[ 0.39165523  0.57732304]] [ 0.00287972]
86 0.0 128.906204224 [[ 0.39098307  0.57819341]] [ 0.00287274]
87 0.0 128.866470337 [[ 0.39038267  0.57898712]] [ 0.00286484]
88 0.0 128.908050537 [[ 0.38988097  0.57921811]] [ 0.00287861]
89 0.0 128.954116821 [[ 0.38945547  0.57965193]] [ 0.00287694]
90 0.0 128.800262451 [[ 0.389137    0.58002861]] [ 0.00288378]
91 0.0 128.92288208 [[ 0.38859685  0.58083128]] [ 0.00286514]
92 0.0 128.908432007 [[ 0.38800051  0.58071185]] [ 0.00291154]
93 0.0 128.88609314 [[ 0.38757896  0.58063595]] [ 0.00293364]
94 0.0 128.869003296 [[ 0.38751798  0.58156812]] [ 0.00288131]
95 0.0 128.860900879 [[ 0.38678266  0.58228704]] [ 0.00288078]
96 0.0 128.903366089 [[ 0.38628568  0.58282011]] [ 0.00288004]
97 0.0 128.904571533 [[ 0.38584994  0.58328447]] [ 0.00288296]
98 0.0 128.930923462 [[ 0.38561887  0.58330287]] [ 0.0028944]
99 0.0 128.881103516 [[ 0.38526236  0.58379243]] [ 0.00289723]
100 0.0 129.006240845 [[ 0.38469441  0.58442394]] [ 0.00289198]
101 0.0 128.954452515 [[ 0.38443698  0.58450104]] [ 0.00290158]
102 0.0 128.935440063 [[ 0.384077    0.58469823]] [ 0.00290237]
103 0.0 128.9584198 [[ 0.38353524  0.58556285]] [ 0.00288461]
104 0.0 128.952026367 [[ 0.38291884  0.58626697]] [ 0.00288531]
105 0.0 128.855697632 [[ 0.38237616  0.5865124 ]] [ 0.00290353]
106 0.0 128.871368408 [[ 0.38219039  0.58697887]] [ 0.00289124]
107 0.0 128.850540161 [[ 0.3817488   0.58821194]] [ 0.00284165]
108 0.0 128.898452759 [[ 0.38116235  0.58833992]] [ 0.00288095]
109 0.0 128.923278809 [[ 0.38070473  0.58878973]] [ 0.00288125]
110 0.0 128.96875 [[ 0.38055238  0.58911536]] [ 0.00287603]
111 0.0 128.893325806 [[ 0.38038318  0.58924596]] [ 0.0028744]
112 0.0 128.938095093 [[ 0.37999511  0.58982333]] [ 0.00286923]
113 0.0 128.899276733 [[ 0.37939309  0.59036952]] [ 0.00287869]
114 0.0 128.918914795 [[ 0.37912198  0.5910628 ]] [ 0.00285744]
115 0.0 128.872741699 [[ 0.37871676  0.59059939]] [ 0.00290906]
116 0.0 128.862289429 [[ 0.37854585  0.5909468 ]] [ 0.00288874]
117 0.0 128.84262085 [[ 0.37806126  0.59163563]] [ 0.00288431]
118 0.0 128.887207031 [[ 0.37780339  0.59171714]] [ 0.00289161]
119 0.0 128.789077759 [[ 0.37769735  0.59183692]] [ 0.00289151]
120 0.0 128.903900146 [[ 0.37787179  0.5917663 ]] [ 0.00288551]
121 0.0 128.941970825 [[ 0.3773057   0.59201047]] [ 0.00289446]
122 0.0 128.844192505 [[ 0.37703253  0.59208754]] [ 0.00290802]
123 0.0 128.98651123 [[ 0.37674404  0.5922313 ]] [ 0.00291257]
124 0.0 128.95854187 [[ 0.37645071  0.59269179]] [ 0.0029035]
125 0.0 128.956390381 [[ 0.37639614  0.59312834]] [ 0.00288366]
126 0.0 128.903564453 [[ 0.37589951  0.59312311]] [ 0.00291978]
127 0.0 128.904144287 [[ 0.37564514  0.59360446]] [ 0.00290726]
128 0.0 128.93510437 [[ 0.37595093  0.59361207]] [ 0.00288203]
129 0.0 128.897415161 [[ 0.37571426  0.5935516 ]] [ 0.00289505]
130 0.0 128.830490112 [[ 0.37551749  0.59409806]] [ 0.00287627]
131 0.0 129.001220703 [[ 0.37505439  0.59440579]] [ 0.00289138]
132 0.0 128.967803955 [[ 0.37488605  0.59468805]] [ 0.00289254]
133 0.0 128.876602173 [[ 0.37489001  0.59482752]] [ 0.0028799]
134 0.0 128.766815186 [[ 0.37486085  0.59456291]] [ 0.00289417]
135 0.0 128.799865723 [[ 0.37478721  0.59502116]] [ 0.00287734]
136 0.0 128.880523682 [[ 0.37418657  0.59502705]] [ 0.0029125]
137 0.0 128.961883545 [[ 0.37384344  0.59501255]] [ 0.00293852]
138 0.0 128.904586792 [[ 0.37390962  0.59633059]] [ 0.00285972]
139 0.0 128.944641113 [[ 0.37361885  0.59591512]] [ 0.00289784]
140 0.0 128.891387939 [[ 0.37316761  0.59594815]] [ 0.00292532]
141 0.0 128.896194458 [[ 0.37340052  0.59658678]] [ 0.00287676]
142 0.0 128.90296936 [[ 0.37257656  0.5971912 ]] [ 0.00291006]
143 0.0 128.871520996 [[ 0.37257698  0.59738387]] [ 0.00289943]
144 0.0 128.999954224 [[ 0.37259972  0.5973054 ]] [ 0.00289632]
145 0.0 128.918563843 [[ 0.37242796  0.59744195]] [ 0.0028947]
146 0.0 128.991836548 [[ 0.37222203  0.59720683]] [ 0.00291328]
147 0.0 128.942947388 [[ 0.37221964  0.59725335]] [ 0.0029119]
148 0.0 129.018218994 [[ 0.37222953  0.59748971]] [ 0.00289863]
149 0.0 128.96635437 [[ 0.37215925  0.59724392]] [ 0.0029111]
150 0.0 129.029327393 [[ 0.37167898  0.59777789]] [ 0.00291006]
151 0.0 129.028549194 [[ 0.37166593  0.59758238]] [ 0.00292053]
152 0.0 129.03944397 [[ 0.37161144  0.59711916]] [ 0.0029392]
153 0.0 128.974411011 [[ 0.37166708  0.59760692]] [ 0.00290708]
154 0.0 129.00177002 [[ 0.37137431  0.5973347 ]] [ 0.00293344]
155 0.0 128.928833008 [[ 0.37126042  0.59745944]] [ 0.00293505]
156 0.0 128.922058105 [[ 0.37115676  0.59750399]] [ 0.00294023]
157 0.0 128.931427002 [[ 0.37104174  0.59769911]] [ 0.00292858]
158 0.0 129.023498535 [[ 0.37106421  0.59748885]] [ 0.00293071]
159 0.0 128.931091309 [[ 0.37129344  0.59776994]] [ 0.00290435]
160 0.0 128.931015015 [[ 0.37126228  0.59799855]] [ 0.0028864]
161 0.0 128.98777771 [[ 0.37106828  0.59767806]] [ 0.00291693]
162 0.0 129.013977051 [[ 0.37100998  0.5977346 ]] [ 0.00291137]
163 0.0 129.050842285 [[ 0.37061796  0.59813126]] [ 0.00291852]
164 0.0 128.965438843 [[ 0.37047256  0.59807982]] [ 0.00292693]
165 0.0 128.977813721 [[ 0.3704239   0.59789633]] [ 0.00293128]
166 0.0 128.996917725 [[ 0.37026423  0.59796526]] [ 0.00293319]
167 0.0 128.953048706 [[ 0.37041396  0.59790435]] [ 0.00291817]
168 0.0 129.051467896 [[ 0.37054646  0.59747624]] [ 0.00292129]
169 0.0 129.001815796 [[ 0.37049385  0.59788682]] [ 0.00289212]
170 0.0 129.193893433 [[ 0.37039478  0.59854108]] [ 0.00286546]
171 0.0 129.044326782 [[ 0.36994972  0.5985627 ]] [ 0.00289845]
172 0.0 129.099197388 [[ 0.36957463  0.59861029]] [ 0.00292944]
173 0.0 129.004165649 [[ 0.36915464  0.59882673]] [ 0.00293991]
174 0.0 129.137313843 [[ 0.36919394  0.5991172 ]] [ 0.00292266]
175 0.0 129.030990601 [[ 0.368967    0.59929281]] [ 0.0029317]
176 0.0 128.986755371 [[ 0.36887077  0.59920901]] [ 0.00293132]
177 0.0 128.919158936 [[ 0.36861224  0.59947615]] [ 0.00293873]
178 0.0 128.931091309 [[ 0.36878886  0.59948977]] [ 0.00291802]
179 0.0 129.058166504 [[ 0.36853284  0.59904856]] [ 0.00295461]
180 0.0 129.009979248 [[ 0.36852206  0.59904342]] [ 0.00294784]
181 0.0 128.963363647 [[ 0.36860663  0.59893317]] [ 0.00294452]
182 0.0 129.04649353 [[ 0.36886134  0.59856405]] [ 0.00294589]
183 0.0 129.126617432 [[ 0.3688607   0.59944689]] [ 0.00289575]
184 0.0 129.020492554 [[ 0.36879574  0.59934046]] [ 0.00289443]
185 0.0 129.004852295 [[ 0.36826078  0.59932896]] [ 0.00293515]
186 0.0 129.069366455 [[ 0.36799872  0.59937272]] [ 0.00295154]
187 0.0 129.093353271 [[ 0.36813222  0.59997342]] [ 0.00291147]
188 0.0 129.157913208 [[ 0.36788541  0.60003179]] [ 0.0029221]
189 0.0 129.16519165 [[ 0.36769674  0.59959069]] [ 0.00295418]
190 0.0 129.11932373 [[ 0.36696122  0.59999098]] [ 0.00297497]
191 0.0 129.150375366 [[ 0.36734883  0.59973205]] [ 0.00295635]
192 0.0 129.176773071 [[ 0.36770963  0.59974048]] [ 0.00292731]
193 0.0 129.110931396 [[ 0.36772502  0.59995855]] [ 0.0029172]
194 0.0 129.133590698 [[ 0.36747734  0.59940277]] [ 0.00296548]
195 0.0 129.143798828 [[ 0.36792321  0.59916319]] [ 0.00293136]
196 0.0 129.084976196 [[ 0.36771901  0.59910295]] [ 0.00295393]
197 0.0 129.0650177 [[ 0.36779315  0.59926384]] [ 0.00293645]
198 0.0 129.105194092 [[ 0.36765568  0.59989751]] [ 0.00291679]
199 0.0 129.050521851 [[ 0.36738974  0.59955724]] [ 0.00293938]
200 0.0 129.138031006 [[ 0.36731831  0.59994245]] [ 0.00292018]
201 0.0 129.147018433 [[ 0.36757447  0.59939487]] [ 0.00292018]
202 0.0 129.103317261 [[ 0.36742254  0.59952686]] [ 0.00292541]
203 0.0 129.135040283 [[ 0.36719568  0.59985038]] [ 0.00292142]
204 0.0 129.051925659 [[ 0.36674151  0.59958854]] [ 0.00296132]
205 0.0 129.094985962 [[ 0.36688201  0.5998171 ]] [ 0.00293808]
206 0.0 129.08744812 [[ 0.36687588  0.59985864]] [ 0.00292737]
207 0.0 129.153305054 [[ 0.36666449  0.60021343]] [ 0.00292772]
208 0.0 129.096847534 [[ 0.36651123  0.60005117]] [ 0.00294135]
209 0.0 129.060821533 [[ 0.36678646  0.59986926]] [ 0.00292934]
210 0.0 129.137115479 [[ 0.36669176  0.59989282]] [ 0.00293522]
211 0.0 129.154067993 [[ 0.36651673  0.59963579]] [ 0.00295442]
212 0.0 129.130493164 [[ 0.36661789  0.59954505]] [ 0.00294753]
213 0.0 129.191345215 [[ 0.36654181  0.59964358]] [ 0.00294596]
214 0.0 129.316650391 [[ 0.36642897  0.59953776]] [ 0.00295726]
215 0.0 129.318572998 [[ 0.36664451  0.59933128]] [ 0.00294944]
216 0.0 129.217178345 [[ 0.36704793  0.59858592]] [ 0.00295255]
217 0.0 129.213821411 [[ 0.36694499  0.59869143]] [ 0.00295501]
218 0.0 129.259765625 [[ 0.36699911  0.5987925 ]] [ 0.00294248]
219 0.0 129.120864868 [[ 0.36699425  0.59879593]] [ 0.0029371]
220 0.0 129.176330566 [[ 0.3669124   0.59799299]] [ 0.00298117]
221 0.0 129.169052124 [[ 0.36706561  0.59799238]] [ 0.00296297]
222 0.0 129.195053101 [[ 0.36704938  0.59794334]] [ 0.00295489]
223 0.0 129.241073608 [[ 0.36736819  0.59747339]] [ 0.00294804]
224 0.0 129.299255371 [[ 0.36728495  0.59769216]] [ 0.00294132]
225 0.0 129.265731812 [[ 0.36720759  0.59823309]] [ 0.0029171]
226 0.0 129.383346558 [[ 0.36706124  0.59848617]] [ 0.00291784]
227 0.0 129.27166748 [[ 0.36669185  0.5980512 ]] [ 0.00296369]
228 0.0 129.360458374 [[ 0.36667323  0.5982137 ]] [ 0.00294739]
229 0.0 129.364593506 [[ 0.3667893  0.5981544]] [ 0.002938]
230 0.0 129.352294922 [[ 0.36645928  0.59797278]] [ 0.00296391]
231 0.0 129.406524658 [[ 0.36657756  0.59779253]] [ 0.00296325]
232 0.0 129.335128784 [[ 0.36649506  0.59808766]] [ 0.0029474]
233 0.0 129.419158936 [[ 0.36662139  0.59793429]] [ 0.00294495]
234 0.0 129.355300903 [[ 0.36651538  0.59776566]] [ 0.0029605]
235 0.0 129.409622192 [[ 0.36679201  0.59756646]] [ 0.00294148]
236 0.0 129.409927368 [[ 0.3668625   0.59732245]] [ 0.00295131]
237 0.0 129.469024658 [[ 0.36703139  0.59640897]] [ 0.00297425]
238 0.0 129.363647461 [[ 0.36721989  0.59696297]] [ 0.0029293]
239 0.0 129.351318359 [[ 0.36647317  0.59729561]] [ 0.00297404]
240 0.0 129.368148804 [[ 0.36644736  0.59689871]] [ 0.00298589]
241 0.0 129.364746094 [[ 0.36663564  0.59722231]] [ 0.00294955]
242 0.0 129.355224609 [[ 0.36656987  0.5965935 ]] [ 0.00298637]
243 0.0 129.393371582 [[ 0.3663969   0.59680174]] [ 0.00297839]
244 0.0 129.377426147 [[ 0.36674881  0.5972421 ]] [ 0.00292699]
245 0.0 129.399627686 [[ 0.36652547  0.59697784]] [ 0.00295231]
246 0.0 129.474853516 [[ 0.36626052  0.59667493]] [ 0.00298789]
247 0.0 129.417617798 [[ 0.36647724  0.59621899]] [ 0.00298715]
248 0.0 129.359848022 [[ 0.3666081   0.59604506]] [ 0.00297796]
249 0.0 129.426040649 [[ 0.36683119  0.59598868]] [ 0.0029673]
250 0.0 129.33241272 [[ 0.36685682  0.59607608]] [ 0.00295932]
251 0.0 129.346359253 [[ 0.36703345  0.59601245]] [ 0.00294354]
252 0.0 129.253173828 [[ 0.3668126   0.59553777]] [ 0.00298214]
253 0.0 129.357803345 [[ 0.36670033  0.59558903]] [ 0.00298343]
254 0.0 129.247894287 [[ 0.36689946  0.59506241]] [ 0.00298889]
255 0.0 129.33442688 [[ 0.36673281  0.59516085]] [ 0.00298445]
256 0.0 129.273803711 [[ 0.36668188  0.59505227]] [ 0.00299164]
257 0.0 129.286346436 [[ 0.36631589  0.59595483]] [ 0.00296601]
258 0.0 129.393066406 [[ 0.36628746  0.59567822]] [ 0.0029736]
259 0.0 129.409851074 [[ 0.36616335  0.59594904]] [ 0.00296785]
260 0.0 129.362075806 [[ 0.36622333  0.59568766]] [ 0.00296886]
261 0.0 129.379348755 [[ 0.36592627  0.59594563]] [ 0.00297201]
262 0.0 129.323989868 [[ 0.36581763  0.59567005]] [ 0.00298643]
263 0.0 129.428741455 [[ 0.36573761  0.59537054]] [ 0.00300049]
264 0.0 129.43522644 [[ 0.3660147   0.59515684]] [ 0.00298546]
265 0.0 129.382293701 [[ 0.36589321  0.59551573]] [ 0.00297341]
266 0.0 129.434555054 [[ 0.3658236   0.59584906]] [ 0.00296564]
267 0.0 129.413452148 [[ 0.36576478  0.5955787 ]] [ 0.00298345]
268 0.0 129.452514648 [[ 0.3658134   0.59530979]] [ 0.00298949]
269 0.0 129.392440796 [[ 0.36588117  0.59465622]] [ 0.00301198]
270 0.0 129.482803345 [[ 0.36586804  0.59506021]] [ 0.00298441]
271 0.0 129.430526733 [[ 0.36583808  0.5946777 ]] [ 0.00299493]
272 0.0 129.52645874 [[ 0.36570326  0.59494811]] [ 0.00298542]
273 0.0 129.476577759 [[ 0.36577358  0.59485843]] [ 0.00298311]
274 0.0 129.598220825 [[ 0.36603395  0.59387732]] [ 0.00300271]
275 0.0 129.487640381 [[ 0.3661678   0.59403057]] [ 0.00298606]
276 0.0 129.531448364 [[ 0.36594996  0.59420898]] [ 0.00298424]
277 0.0 129.512252808 [[ 0.36591026  0.59356749]] [ 0.00300937]
278 0.0 129.563247681 [[ 0.36596078  0.59348945]] [ 0.00300612]
279 0.0 129.506851196 [[ 0.36574507  0.59352578]] [ 0.00301614]
280 0.0 129.524215698 [[ 0.36602243  0.59382677]] [ 0.00297631]
281 0.0 129.550354004 [[ 0.36614667  0.59323045]] [ 0.00299063]
282 0.0 129.488265991 [[ 0.36608478  0.59287743]] [ 0.0030056]
283 0.0 129.475952148 [[ 0.36634842  0.59256625]] [ 0.00299526]
284 0.0 129.55506897 [[ 0.36633244  0.59248964]] [ 0.00299511]
285 0.0 129.463882446 [[ 0.36650764  0.59242237]] [ 0.00297891]
286 0.0 129.583618164 [[ 0.36640442  0.59214359]] [ 0.00300369]
287 0.0 129.563064575 [[ 0.36678147  0.59134037]] [ 0.00301214]
288 0.0 129.628890991 [[ 0.36651433  0.59153775]] [ 0.00302042]
289 0.0 129.606918335 [[ 0.36652805  0.59202953]] [ 0.00299264]
290 0.0 129.573547363 [[ 0.3663155   0.59203083]] [ 0.00299899]
291 0.0 129.682647705 [[ 0.36614706  0.59247764]] [ 0.00299029]
292 0.0 129.675552368 [[ 0.36591388  0.59215301]] [ 0.00301429]
293 0.0 129.624450684 [[ 0.36584536  0.59237573]] [ 0.00300571]
294 0.0 129.718490601 [[ 0.36587519  0.59250261]] [ 0.002994]
295 0.0 129.707321167 [[ 0.36584296  0.59227786]] [ 0.00300724]
296 0.0 129.656829834 [[ 0.36583066  0.59185925]] [ 0.0030265]
297 0.0 129.621826172 [[ 0.36585217  0.59233968]] [ 0.00299816]
298 0.0 129.686721802 [[ 0.36589665  0.59224447]] [ 0.00299296]
299 0.0 129.77456665 [[ 0.36565435  0.59250126]] [ 0.00299317]
300 0.0 129.748199463 [[ 0.3655837  0.592026 ]] [ 0.00301351]
301 0.0 129.768203735 [[ 0.36572942  0.59238103]] [ 0.00298356]
302 0.0 129.707015991 [[ 0.36579325  0.59250059]] [ 0.00296696]
303 0.0 129.564178467 [[ 0.36531534  0.59214263]] [ 0.0030189]
304 0.0 129.734146118 [[ 0.36516225  0.59224641]] [ 0.00302108]
305 0.0 129.702423096 [[ 0.36542902  0.59210476]] [ 0.00300309]
306 0.0 129.787139893 [[ 0.36593971  0.59160446]] [ 0.00298417]
307 0.0 129.818389893 [[ 0.36569358  0.5916045 ]] [ 0.00300596]
308 0.0 129.818466187 [[ 0.3653618   0.59152276]] [ 0.00303178]
309 0.0 129.811950684 [[ 0.36560174  0.59130093]] [ 0.00301187]
310 0.0 129.937942505 [[ 0.36548334  0.59148906]] [ 0.00300657]
311 0.0 129.929428101 [[ 0.36558801  0.59132254]] [ 0.00300509]
312 0.0 129.995315552 [[ 0.36549275  0.59129864]] [ 0.0030074]
313 0.0 129.960296631 [[ 0.36538175  0.59104409]] [ 0.00302275]
314 0.0 129.829864502 [[ 0.36526312  0.59122911]] [ 0.00301706]
315 0.0 129.988143921 [[ 0.36533881  0.59094243]] [ 0.00303694]
316 0.0 129.977401733 [[ 0.36517701  0.5903076 ]] [ 0.00306745]
317 0.0 129.975402832 [[ 0.3652906   0.59035126]] [ 0.00304737]
318 0.0 129.929458618 [[ 0.36541656  0.59033845]] [ 0.00303536]
319 0.0 130.019439697 [[ 0.36586394  0.59017359]] [ 0.00300621]
320 0.0 130.029464722 [[ 0.36577127  0.59004345]] [ 0.00301819]
321 0.0 130.01789856 [[ 0.36557895  0.58976912]] [ 0.00304696]
322 0.0 130.093917847 [[ 0.36559555  0.5905059 ]] [ 0.00299964]
323 0.0 130.058563232 [[ 0.36534338  0.58952261]] [ 0.00306186]
324 0.0 130.048171997 [[ 0.36499417  0.58959412]] [ 0.00307409]
325 0.0 129.91960144 [[ 0.36553323  0.58933778]] [ 0.00305058]
326 0.0 130.049591064 [[ 0.36545315  0.5890781 ]] [ 0.00305912]
327 0.0 130.0574646 [[ 0.36565443  0.58898405]] [ 0.00305186]
328 0.0 130.084152222 [[ 0.36577027  0.58939715]] [ 0.00301683]
329 0.0 129.969848633 [[ 0.36576216  0.58920954]] [ 0.00302424]
330 0.0 130.105102539 [[ 0.36578318  0.58909534]] [ 0.00302375]
331 0.0 130.08682251 [[ 0.36551807  0.58898984]] [ 0.0030394]
332 0.0 130.050033569 [[ 0.36515615  0.5894476 ]] [ 0.00303656]

\\srv-file.brml.tum.de\nthome\cwolf\code\climin\climin\util.py:150: UserWarning: Argument named f is not expected by <class 'climin.adam.Adam'>
  % (i, klass))

In [86]:
train_result_params = m.parameters.data.copy()
m.parameters.data = info['best_pars']
m.score(VX)
Out[86]:
garray(128.7577667236328)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [87]:
f_z_init_sample = m.function(['inpt'], m.init_recog.sample(), numpy_result=True)
f_z_sample = m.function(['inpt'], m.hmc_sampler.output, numpy_result=True)
f_gen = m.function([m.gen.inpt], m.gen.sample(), numpy_result=True)
f_gen_rate = m.function([m.gen.inpt], m.gen.rate, numpy_result=True)
f_joint_nll = m.function(['inpt'], m.joint_nll, numpy_result=True)
In [88]:
curr_pos = T.matrix('current_position')
curr_vel = T.matrix('current_velocity')
norm_noise = T.matrix('normal_noise')
unif_noise = T.vector('uniform_noise')

new_sampled_vel = m.hmc_sampler.kin_energy.sample(norm_noise)
updated_vel = m.hmc_sampler.partial_vel_constant * curr_vel + m.hmc_sampler.partial_vel_complement * new_sampled_vel
performed_hmc_steps = m.hmc_sampler.perform_hmc_steps(curr_pos, curr_vel)
hmc_step = m.hmc_sampler.hmc_step(curr_pos, curr_vel, np.float32(0), norm_noise, unif_noise)
lf_step_results = m.hmc_sampler.simulate_dynamics(curr_pos, curr_vel, return_full_list=True)

f_pot_en = m.function(['inpt', curr_pos], m.hmc_sampler.eval_pot_energy(curr_pos), numpy_result=True)
f_kin_en = m.function(['inpt', curr_vel], m.kin_energy.nll(curr_vel).sum(-1), numpy_result=True)
f_perform_hmc_steps = m.function(['inpt', curr_pos, curr_vel], 
                                T.concatenate([performed_hmc_steps[0], performed_hmc_steps[1]], axis=1))
f_hmc_step = m.function(['inpt', curr_pos, curr_vel, norm_noise, unif_noise], 
                        T.concatenate([hmc_step[0], hmc_step[1]],axis=1), on_unused_input='warn')
f_kin_energy_sample_from_noise = m.function(['inpt', norm_noise], new_sampled_vel)
f_updated_vel_from_noise = m.function(['inpt', curr_vel, norm_noise], updated_vel)
f_perform_lf_steps = m.function(['inpt', curr_pos, curr_vel],
                               T.concatenate([lf_step_results[0], lf_step_results[1]], axis=0))
In [89]:
f_z_init_mean = m.function(['inpt'], m.init_recog.mean, numpy_result=True)
f_z_init_var = m.function(['inpt'], m.init_recog.var, numpy_result=True)

f_v_init_var = m.function(['inpt'], T.extra_ops.cpu_contiguous(m.kin_energy.var), numpy_result=True)

full_sample = m.hmc_sampler.sample_with_path()
f_full_sample = m.function(['inpt'], T.concatenate([full_sample[0], full_sample[1]], axis=1))
In [90]:
final_pos = T.matrix('final_pos')
final_vel = T.matrix('final_vel')
inpt_replacements = {m.final_vel_model_inpt['position']: final_pos,
                     m.final_vel_model_inpt['time']: T.cast(m.hmc_sampler.n_hmc_steps, dtype='float32')}

final_vel_model_var = clone(m.final_vel_model.var, replace=inpt_replacements)
final_vel_model_mean = clone(m.final_vel_model.mean, replace=inpt_replacements)
final_vel_model_nll = clone(m.final_vel_model.nll(final_vel).sum(-1), replace=inpt_replacements)

f_v_final_var = m.function(['inpt', final_pos], final_vel_model_var, numpy_result=True)
f_v_final_mean = m.function(['inpt', final_pos], final_vel_model_mean, numpy_result=True)
f_v_final_model_nll = m.function(['inpt', final_pos, final_vel], final_vel_model_nll, numpy_result=True)

f_kin_energy_nll = m.function(['inpt'], m.kin_energy.expected_nll, numpy_result=True)
In [91]:
f_init_recog_nll = m.function(['inpt'], m.init_recog.expected_nll.sum(-1), numpy_result=True)
In [92]:
print f_init_recog_nll(VX).mean()
init_var = f_z_init_var(VX)
print init_var.mean()
print init_var.max()
print init_var.min()
-4.16031
0.00134758
0.0745301
3.34876e-05

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [93]:
fig, axs = plt.subplots(2, 3, figsize=(27, 18))

### Original data

O = (X_no_bin_np[:64])[:, :784].astype('float32')
img = tile_raster_images(O, image_dims, (8, 8), (1, 1))
axs[0, 0].imshow(img, cmap=cm.binary)

O2 = (X_np[:64])[:, :784].astype('float32')
img = tile_raster_images(O2, image_dims, (8, 8), (1, 1))
axs[1, 0].imshow(img, cmap=cm.binary)

### Reconstruction

#z_sample = f_z_sample((X[:64]))
z_init_sample = cast_array_to_local_type(f_z_init_sample((X[:64])))
z_sample = f_perform_hmc_steps((X[:64]), 
                               z_init_sample, 
                               f_kin_energy_sample_from_noise((X[:64]), 
                                                              cast_array_to_local_type(np.random.normal(size=(64, m.n_latent)).astype('float32')))
                               )[-1, :64, :]

R = f_gen_rate(z_sample)[:, :784].astype('float32')
img = tile_raster_images(R, image_dims, (8, 8), (1, 1))
axs[0, 1].imshow(img, cmap=cm.binary)

Rinit = f_gen_rate(z_init_sample)[:, :784].astype('float32')
img = tile_raster_images(Rinit, image_dims, (8, 8), (1, 1))
axs[0, 2].imshow(img, cmap=cm.binary)

R2 = f_gen(z_sample)[:, :784].astype('float32')
img = tile_raster_images(R2, image_dims, (8, 8), (1, 1))
axs[1, 1].imshow(img, cmap=cm.binary)

Rinit2 = f_gen(z_init_sample)[:, :784].astype('float32')
img = tile_raster_images(Rinit2, image_dims, (8, 8), (1, 1))
axs[1, 2].imshow(img, cmap=cm.binary)
Out[93]:
<matplotlib.image.AxesImage at 0x2630e6a0>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [94]:
fig, axs = plt.subplots(1, 2, figsize=(18, 9))

prior_sample = cast_array_to_local_type(np.random.randn(64, m.n_latent))

S = f_gen_rate(prior_sample)[:, :784].astype('float32')
img = tile_raster_images(S, image_dims, (8, 8), (1, 1))
axs[0].imshow(img, cmap=cm.binary)

S2 = f_gen(prior_sample)[:, :784].astype('float32')
img = tile_raster_images(S2, image_dims, (8, 8), (1, 1))
axs[1].imshow(img, cmap=cm.binary)

#S3 = f_gen_rate(prior_sample)[:, :784].astype('float32')
#img = tile_raster_images(S, image_dims, (8, 8), (1, 1))
#axs[2, 2].imshow(img, cmap=cm.nipy_spectral)
Out[94]:
<matplotlib.image.AxesImage at 0x7439d5f8>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [95]:
# TODO: Axis titles, plot title, make this work if one selects two dimensions out of more than two (i.e. if n_latent>2)

from scipy.stats import norm as normal_distribution

unit_interval_positions = np.linspace(0.025, 0.975, 20)
positions = normal_distribution.ppf(unit_interval_positions)
print unit_interval_positions
print positions

latent_array = np.zeros((400, 2))

latent_array[:, 1] = -np.repeat(positions, 20)  # because images are filled top -> bottom, left -> right (row by row)
latent_array[:, 0] = np.tile(positions, 20)
        
fig, axs = plt.subplots(1, 1, figsize=(24, 24))

F = f_gen_rate(cast_array_to_local_type(latent_array)).astype('float32')

img = tile_raster_images(F, image_dims, (20, 20), (1, 1))
#axs.imshow(img, cmap=cm.nipy_spectral)
axs.imshow(img, cmap=cm.binary)
[ 0.025  0.075  0.125  0.175  0.225  0.275  0.325  0.375  0.425  0.475
  0.525  0.575  0.625  0.675  0.725  0.775  0.825  0.875  0.925  0.975]
[-1.95996398 -1.43953147 -1.15034938 -0.93458929 -0.75541503 -0.59776013
 -0.45376219 -0.31863936 -0.18911843 -0.06270678  0.06270678  0.18911843
  0.31863936  0.45376219  0.59776013  0.75541503  0.93458929  1.15034938
  1.43953147  1.95996398]

Out[95]:
<matplotlib.image.AxesImage at 0x27093be0>
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [96]:
L = f_z_sample(X)
L_init = f_z_init_sample(X)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [97]:
dim1 = 0
dim2 = 1
In [98]:
fig, axs = plt.subplots(1, 2, figsize=(18, 9))
axs[0].scatter(L[:, dim1], L[:, dim2], c=Z[:].argmax(1), lw=0, s=5, alpha=.2)
axs[1].scatter(L_init[:, dim1], L_init[:, dim2], c=Z[:].argmax(1), lw=0, s=5, alpha=.2)

cax = fig.add_axes([0.95, 0.2, 0.02, 0.6])
cax.scatter(np.repeat(0, 10), np.arange(10), c=np.arange(10), lw=0, s=300)
cax.set_xlim(-0.1, 0.1)
cax.set_ylim(-0.5, 9.5)
plt.yticks(np.arange(10))
plt.tick_params(axis='x', which='both', bottom='off', top='off', labelbottom='off')
cax.tick_params(axis='y', colors='white')
for tick in cax.yaxis.get_major_ticks():
    tick.label.set_fontsize(14)
    tick.label.set_color('black')
    
cax.spines['bottom'].set_color('white')
cax.spines['top'].set_color('white') 
cax.spines['right'].set_color('white')
cax.spines['left'].set_color('white')

axs[0].set_title('After HMC steps')
axs[1].set_title('Initial recognition model')

axs[0].set_xlim(-3, 3)
axs[0].set_ylim(-3, 3)
axs[1].set_xlim(-3, 3)
axs[1].set_ylim(-3, 3)
Out[98]:
(-3, 3)
In [99]:
fig, axs = plt.subplots(4, 5, figsize=(20, 16))
colors = cm.jet(np.linspace(0, 1, 10))
for i in range(5):
    axs[0, i].scatter(L_init[Z[:].argmax(1) == i, dim1], L_init[Z[:].argmax(1) == i, dim2], c=colors[i], lw=0, s=5, alpha=.2)
    axs[1, i].scatter(L[Z[:].argmax(1) == i, dim1], L[Z[:].argmax(1) == i, dim2], c=colors[i], lw=0, s=5, alpha=.2)
    axs[0, i].set_title(str(i) + ' before HMC')
    axs[1, i].set_title(str(i) + ' after HMC')
    axs[2, i].scatter(L_init[Z[:].argmax(1) == (5+i), dim1], L_init[Z[:].argmax(1) == (5+i), dim2], c=colors[5+i], lw=0, s=5, alpha=.2)
    axs[3, i].scatter(L[Z[:].argmax(1) == (5+i), dim1], L[Z[:].argmax(1) == (5+i), dim2], c=colors[5+i], lw=0, s=5, alpha=.2)
    axs[2, i].set_title(str(5+i) + ' before HMC')
    axs[3, i].set_title(str(5+i) + ' after HMC')
    for j in range(4):
        axs[j, i].set_xlim(-3, 3)
        axs[j, i].set_ylim(-3, 3)
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [100]:
X_index = 0  # index=0 -> 5, index=1 -> 0, index=2 -> 4, index=3 -> 1, index=24 -> underlined 1, index=39 -> ugly 6
num_repeats = 1000

fig, axs = plt.subplots(1, 2, figsize=(6, 3))
img = tile_raster_images(np.array([X[X_index, :]]), image_dims, (1, 1), (1, 1))
axs[0].imshow(img, cmap=cm.binary)
img = tile_raster_images(np.array([X_no_bin[X_index, :]]), image_dims, (1, 1), (1, 1))
axs[1].imshow(img, cmap=cm.binary)
Out[100]:
<matplotlib.image.AxesImage at 0x7c5d7208>
In [101]:
repeated_X = cast_array_to_local_type(np.tile(np.array([X[X_index, :]]), (num_repeats, 1)).astype('float32'))

full_sample = f_full_sample(repeated_X).astype('float32')
z_samples = full_sample[:, :num_repeats, :]
v_samples = full_sample[:, num_repeats:, :]

z_sample_final_mean = z_samples[m.n_hmc_steps, :, :].mean(axis=0)
z_sample_final_std = z_samples[m.n_hmc_steps, :, :].std(axis=0)

single_X = cast_array_to_local_type(np.array([X[X_index, :]]).astype('float32'))
init_mean = f_z_init_mean(single_X)[0]
init_var = f_z_init_var(single_X)[0]

init_vel_var = f_v_init_var(single_X)[0]

print 'Posterior distribution statistics'
print
print 'Initial model: - Mean: ' + str(init_mean)
print '               - Var:  ' + str(init_var)
print
print 'Full HVI model: - Mean: ' + str(z_sample_final_mean)
print '                - Var:  ' + str(z_sample_final_std ** 2)
print
print 'Velocity model variance: ' + str(init_vel_var)
Posterior distribution statistics

Initial model: - Mean: [ 0.34112886 -0.18799755]
               - Var:  [ 0.00064511  0.00014962]

Full HVI model: - Mean: [ 0.37444857 -0.18244392]
                - Var:  [ 0.00056328  0.0034712 ]

Velocity model variance: [ 0.40236035  0.56585824]

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [102]:
dim1 = 0
dim2 = 1
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [110]:
resolution = 201
lower_dim1_limit = z_sample_final_mean[dim1] - 0.2
upper_dim1_limit = z_sample_final_mean[dim1] + 0.2
lower_dim2_limit = z_sample_final_mean[dim2] - 0.2
upper_dim2_limit = z_sample_final_mean[dim2] + 0.2

number_images_per_axis = 11
latent_array = np.zeros((number_images_per_axis**2, 2))
gap_between_images = (resolution - 1)//(number_images_per_axis - 1)
indices_for_images = np.arange(0, resolution, gap_between_images)

pot_energy_matrix = np.zeros((resolution, resolution), dtype='float32')
x = np.linspace(lower_dim1_limit, upper_dim1_limit, resolution)
y = np.linspace(lower_dim2_limit, upper_dim2_limit, resolution)
for i in range(resolution):
    for j in range(resolution):
        #pos_array = f_z_init_mean(single_X)
        pos_array = np.array([z_sample_final_mean])
        pos_array[0, dim1] = x[i]
        pos_array[0, dim2] = y[j]
        pos_array_of_local_type = cast_array_to_local_type(pos_array)
        pot_energy_matrix[j, i] = f_pot_en(single_X, pos_array_of_local_type)[0]
        if i in indices_for_images and j in indices_for_images:
            latent_array[(i//gap_between_images) + (number_images_per_axis - 1 - j//gap_between_images)*number_images_per_axis , :] = pos_array[0, :]

        
print 'Minimum potential energy (at grid points): ' + str(pot_energy_matrix.min())
print 'Maximum potential energy (at grid points): ' + str(pot_energy_matrix.max())

fig, axs = plt.subplots(1, 2, figsize=(18, 9))
CS = axs[0].contour(x, y, pot_energy_matrix, 40)
plt.clabel(CS, inline=1, fmt='%1.0f', fontsize=10)
axs[0].set_title('Potential energy surface')

F = f_gen_rate(cast_array_to_local_type(latent_array))
img = tile_raster_images(F, image_dims, (number_images_per_axis, number_images_per_axis), (1, 1))
#axs.imshow(img, cmap=cm.nipy_spectral)
axs[1].imshow(img, cmap=cm.binary)
plt.show()
Minimum potential energy (at grid points): 179.766
Maximum potential energy (at grid points): 299.986

In [104]:
resolution = 200
underlying_variance = f_v_init_var(single_X)
velocity_range_for_images = 10.0 * np.sqrt(underlying_variance[0, :])
lower_dim1_limit = np.around(- velocity_range_for_images[dim1])
upper_dim1_limit = np.around(  velocity_range_for_images[dim1])
lower_dim2_limit = np.around(- velocity_range_for_images[dim2])
upper_dim2_limit = np.around(  velocity_range_for_images[dim2])

kin_energy_matrix = np.zeros((resolution, resolution), dtype='float32')
kin_x = np.linspace(lower_dim1_limit, upper_dim1_limit, resolution)
kin_y = np.linspace(lower_dim2_limit, upper_dim2_limit, resolution)
for i in range(resolution):
    for j in range(resolution):
        vel_array = np.zeros((1, m.n_latent)).astype('float32')
        vel_array[0, dim1] = kin_x[i]
        vel_array[0, dim2] = kin_y[j]
        vel_array_of_local_type = cast_array_to_local_type(vel_array)
        kin_energy_matrix[j, i] = f_kin_en(single_X, vel_array_of_local_type)

print 'Minimum kinetic energy (at grid points): ' + str(kin_energy_matrix.min())
print 'Maximum kinetic energy (at grid points): ' + str(kin_energy_matrix.max())

fig, ax = plt.subplots(1, 1, figsize=(9, 9))
CS = ax.contour(kin_x, kin_y, kin_energy_matrix)
plt.axes().set_aspect('equal', 'datalim')
plt.clabel(CS, inline=1, fmt='%1.1f', fontsize=10)
ax.set_title('Kinetic energy surface')
plt.show()
Minimum kinetic energy (at grid points): 1.10053
Maximum kinetic energy (at grid points): 102.385

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [115]:
fig, axs = plt.subplots(m.n_hmc_steps + 1, 3, figsize=(18, (m.n_hmc_steps + 1) * 6))
colors = cm.jet(np.linspace(0, 1, 10))

#contour_levels = (198, 200, 202, 204, 206, 208, 210)
#contour_levels = (120, 130, 140, 150, 160, 180, 200, 240, 280)
#contour_levels = (100, 102, 104, 106, 108, 110, 115, 120, 125, 130)
#contour_levels = (400, 402, 404, 406, 408, 410, 412, 416, 420)
#contour_levels = (106, 108, 110, 112, 114, 116, 118, 120, 124, 128)
contour_levels = (160, 165, 170, 175, 180, 185, 190, 195, 200, 210, 220, 230, 240, 250, 270, 300)
#contour_levels = (174, 175, 176, 177, 178, 180, 182, 184, 186, 190, 200)
#contour_levels = (59, 61, 63, 65, 67, 69, 71, 73, 75, 80, 85, 90)

vel_contour_levels = np.linspace(2.0, 70.0, 18)
#CS0 = axs[0, 0].contourf(x, y, pot_energy_matrix, np.linspace(155, 240, 500))

def colour_for_z_samples(samples):
    mean = samples.mean(axis=0)
    mean1 = mean[dim1]
    mean2 = mean[dim2]
    colour = np.zeros_like(samples[:, 0])
    colour[np.logical_and(samples[:, dim1] < mean1,  samples[:, dim2] < mean2)] = 0
    colour[np.logical_and(samples[:, dim1] < mean1,  samples[:, dim2] >= mean2)] = 2
    colour[np.logical_and(samples[:, dim1] >= mean1, samples[:, dim2] < mean2)] = 4
    colour[np.logical_and(samples[:, dim1] >= mean1, samples[:, dim2] >= mean2)] = 7
    colour[((samples[:, dim1] - mean1) ** 2 + (samples[:, dim2] - mean2) ** 2) < 1e-5] = 9
    return colour.astype('int32')

colour = colour_for_z_samples(z_samples[m.n_hmc_steps,:,:])
print v_samples[m.n_hmc_steps, colour == 0, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 2, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 4, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 7, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 9, :].mean(axis=0)
print v_samples[m.n_hmc_steps, colour == 0, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 2, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 4, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 7, :].var(axis=0)
print v_samples[m.n_hmc_steps, colour == 9, :].var(axis=0)

for i in range(m.n_hmc_steps + 1):
    CS = axs[i, 0].contour(x, y, pot_energy_matrix, contour_levels)
    plt.clabel(CS, inline=1, fmt='%1.0f', fontsize=10)
    axs[i, 0].scatter(z_samples[i,:,dim1], z_samples[i,:,dim2], c=colors[colour_for_z_samples(z_samples[i,:,:])], s=20, alpha=.3, lw=0)
    
    CS_vel = axs[i, 1].contour(kin_x, kin_y, kin_energy_matrix, vel_contour_levels)
    plt.clabel(CS_vel, inline=1, fmt='%1.1f', fontsize=10)
    axs[i, 1].scatter(v_samples[i,:,dim1], v_samples[i,:,dim2], c=colors[colour_for_z_samples(z_samples[i,:,:])], s=20, alpha=.3, lw=0)
    
    pot_energy_distrib = f_pot_en(repeated_X, cast_array_to_local_type(z_samples[i, :, :]))
    if i == 0:
        max_x_value_for_hist = pot_energy_distrib.max() + 5
        min_x_value_for_hist = np.floor(pot_energy_matrix.min()) -5
    pot_energy_distrib_mean = pot_energy_distrib.mean()
    axs[i, 2].hist(pot_energy_distrib, 30, normed=1, range=(min_x_value_for_hist, max_x_value_for_hist))
    axs[i, 2].autoscale(enable=False, axis='both')
    axs[i, 2].axvline(pot_energy_distrib_mean, color='r', linestyle='dashed', linewidth=2)
    axs[i, 2].set_xlim(min_x_value_for_hist, max_x_value_for_hist)
    axs[i, 2].text(pot_energy_distrib_mean + 1.0, 0.8*axs[i, 2].get_ylim()[1], 'Mean: ' + str(pot_energy_distrib_mean))
    axs[i, 1].set_xlim(-velocity_range_for_images[dim1], velocity_range_for_images[dim1])
    axs[i, 1].set_ylim(-velocity_range_for_images[dim2], velocity_range_for_images[dim2])
    axs[i, 1].set_aspect('equal', 'datalim')
    axs[i, 0].set_aspect('equal', 'datalim')

axs[0, 0].scatter(f_z_init_mean(single_X)[0, dim1], f_z_init_mean(single_X)[0, dim2], c='black', s=20)

plt.show()
[ 0.37442935 -0.3596921 ]
[ 0.64236486  1.79007053]
[ 0.6661306 -0.4157398]
[ 0.53253555  1.78914237]
[ nan  nan]
[ 0.42884281  1.16761553]
[ 0.58845699  1.29838049]
[ 0.33406088  0.86715931]
[ 0.41920319  1.18243825]
[ nan  nan]

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [112]:
np.random.seed(1)

velocity_noise = cast_array_to_local_type(np.random.normal(size=(m.n_hmc_steps, 1, m.n_latent)))
#velocity_noise = np.zeros_like(velocity_noise)

init_pos = f_z_init_mean(single_X) # + np.array([0.0, 0.1])
init_vel = f_kin_energy_sample_from_noise(single_X, velocity_noise[0])

num_vels_per_hmc = (m.n_lf_steps + 2)
position_array = np.zeros((m.n_hmc_steps * m.n_lf_steps + 1, m.n_latent))
position_array[0] = init_pos
velocity_array = np.zeros((m.n_hmc_steps * num_vels_per_hmc, m.n_latent))
velocity_array[0] = ma.assert_numpy(init_vel)

for hmc_num in range(m.n_hmc_steps):
    if hmc_num == 0:
        curr_pos = cast_array_to_local_type(init_pos)
        curr_vel = init_vel
    else:
        curr_vel = f_updated_vel_from_noise(single_X, curr_vel, velocity_noise[hmc_num])
        velocity_array[hmc_num * (m.n_lf_steps + 2)] = ma.assert_numpy(curr_vel)
    
    lf_step_results = f_perform_lf_steps(single_X, curr_pos, curr_vel)
    pos_steps = lf_step_results[:m.n_lf_steps]
    vel_half_steps_and_final = lf_step_results[m.n_lf_steps:]
    final_vel = lf_step_results[-1]
    final_pos = pos_steps[-1]
    
    position_array[hmc_num * m.n_lf_steps + 1: (hmc_num + 1)*m.n_lf_steps + 1] = ma.assert_numpy(pos_steps[:, 0, :])
    velocity_array[hmc_num * num_vels_per_hmc + 1: (hmc_num + 1) * num_vels_per_hmc] = ma.assert_numpy(vel_half_steps_and_final[:, 0, :])
    
    curr_pos = final_pos
    curr_vel = final_vel
In [114]:
fig, axs = plt.subplots(1, 2, figsize=(18, 9))
step_color = cm.jet(np.linspace(0, 1, position_array.shape[0]))
CS = axs[0].contour(x, y, pot_energy_matrix, 40)
CS_vel = axs[1].contour(kin_x, kin_y, kin_energy_matrix, vel_contour_levels)
hmc_step_indices = np.arange(0, position_array.shape[0], m.n_lf_steps)
size_array = 40*np.ones((position_array.shape[0],))
size_array[hmc_step_indices] = 100
axs[0].scatter(position_array[:, dim1], position_array[:, dim2], c=step_color, lw=1, s=size_array)
axs[1].set_color_cycle(step_color)

for hmc_num in range(m.n_hmc_steps):
    curr_vel_range = np.arange(num_vels_per_hmc * hmc_num, num_vels_per_hmc * (hmc_num + 1) - 2)
    init_vel_ind = hmc_num * num_vels_per_hmc
    final_vel_ind = (hmc_num + 1) * num_vels_per_hmc - 1
    curr_index = hmc_step_indices[hmc_num]
    next_index = hmc_step_indices[hmc_num + 1]
    for j in curr_vel_range:
        axs[1].plot(velocity_array[j:j+2, dim1], velocity_array[j:j+2, dim2], lw=2)
    axs[1].scatter(velocity_array[init_vel_ind, dim1], velocity_array[init_vel_ind, dim2], c=step_color[curr_index], lw=0, s=100)
    axs[1].scatter(velocity_array[final_vel_ind, dim1], velocity_array[final_vel_ind, dim2], c=step_color[next_index], lw=0, s=100)

for hmc_num in range(m.n_hmc_steps):
    final_vel_ind = (hmc_num + 1) * num_vels_per_hmc - 1
    next_index = hmc_step_indices[hmc_num + 1]
    axs[1].plot(velocity_array[final_vel_ind-1:final_vel_ind+1, dim1], velocity_array[final_vel_ind-1:final_vel_ind+1, dim2], lw=2, c=step_color[next_index])

axs[0].set_aspect('equal', 'datalim')
axs[1].set_aspect('equal', 'datalim')
Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [118]:
# TODO: Do this for the intermediate aux vel model steps

variation_start = z_sample_final_mean - 2*z_sample_final_std
variation_end = z_sample_final_mean + 2*z_sample_final_std
print variation_start
print variation_end
final_vel_model_mean_output = np.zeros((m.n_latent, num_repeats, m.n_latent))
final_vel_model_var_output = np.zeros((m.n_latent, num_repeats, m.n_latent))

for variation_dim in range(m.n_latent):
    z_variation = np.linspace(variation_start[variation_dim], variation_end[variation_dim], num_repeats)
    sample_array = np.tile(z_sample_final_mean, (num_repeats, 1))
    sample_array[:, variation_dim] = z_variation
    final_vel_model_mean_output[variation_dim] = f_v_final_mean(repeated_X, cast_array_to_local_type(sample_array))
    final_vel_model_var_output[variation_dim] = f_v_final_var(repeated_X, cast_array_to_local_type(sample_array))
[ 0.32698157 -0.30027768]
[ 0.42191556 -0.06461015]

In [117]:
fig, axs = plt.subplots(1, 2, figsize=(18, 9))
axs[0].scatter(final_vel_model_mean_output[:, :, dim1], 
           final_vel_model_mean_output[:, :, dim2],  
           c=np.transpose(np.tile(np.linspace(0,m.n_latent-1,m.n_latent), (num_repeats, 1))), 
           lw=0, s=5)
axs[1].scatter(final_vel_model_var_output[:, :, dim1], 
           final_vel_model_var_output[:, :, dim2],  
           c=np.transpose(np.tile(np.linspace(0,m.n_latent-1,m.n_latent), (num_repeats, 1))), 
           lw=0, s=5)

plt.show()
In [119]:
final_z_samples = cast_array_to_local_type(z_samples[m.n_hmc_steps, :, :])
final_v_samples = cast_array_to_local_type(v_samples[m.n_hmc_steps, :, :])
final_vel_mean = f_v_final_mean(repeated_X, final_z_samples)
final_vel_var = f_v_final_var(repeated_X, final_z_samples)
final_vel_nll = f_v_final_model_nll(repeated_X, final_z_samples, final_v_samples)
In [120]:
print f_kin_energy_nll(single_X).sum(-1)

print final_vel_nll.mean()
print final_vel_nll.min()
print final_vel_nll.max()
[ 2.09796762]
3.43918
1.60792
21.6117

Node Commands Syntax: node {operator} [options] [arguments] Parameters: /? or /help - Display this help message. list - List nodes or node history or the cluster listcores - List cores on the cluster view - View properties of a node online - Set nodes or node to online state offline - Set one or more nodes to the offline state For more information about HPC command-line tools, see http://go.microsoft.com/fwlink/?LinkId=120724.
In [None]:
fig, axs = plt.subplots(4, 2, figsize=(18, 36))
# TODO: Analysis of how final_vel_mean and final_vel_var depend on z (since they all share the same x)

print z_samples[3, :, :].mean(axis=0)
print z_samples[3, :, :].var(axis=0)
print v_samples[3, :, :].mean(axis=0)
print v_samples[3, :, :].var(axis=0)
print f_v_init_var(np.array([X[X_index, :]]))

print final_vel_nll.mean()
plt.boxplot(final_vel_nll, whis=1)
plt.show()
In [None]:
centers = np.zeros((10,n_latents))
stddevs = np.zeros((10,n_latents))
centers_init = np.zeros((10,n_latents))
stddevs_init = np.zeros((10,n_latents))
for i in range(10):
    Li = f_z_sample(X[Z.argmax(1) == i])
    centers[i] = Li.mean(axis=0)
    stddevs[i] = np.std(Li, axis=0)
    
    Li_init = f_z_init_sample(X[Z.argmax(1) == i])
    centers_init[i] = Li_init.mean(axis=0)
    stddevs_init[i] = np.std(Li_init, axis=0)
In [None]:
fig, axs = plt.subplots(1, 2, figsize=(18, 9))
axs[0].scatter(centers[:, dim1], centers[:, dim2], c=range(10), s=50)
axs[0].scatter(centers_init[:, dim1], centers_init[:, dim2], c=range(10), s=50, marker=u's')

axs[1].scatter(centers[:, dim1], centers[:, dim2], c=range(10), s=50)
axs[1].scatter(centers[:, dim1] + stddevs[:, dim1], centers[:, dim2], c=range(10), s=50, marker=u'>')
axs[1].scatter(centers[:, dim1] - stddevs[:, dim1], centers[:, dim2], c=range(10), s=50, marker=u'<')
axs[1].scatter(centers[:, dim1], centers[:, dim2] + stddevs[:, dim2], c=range(10), s=50, marker=u'^')
axs[1].scatter(centers[:, dim1], centers[:, dim2] - stddevs[:, dim2], c=range(10), s=50, marker=u'v')

#axs[0].set_xlim(-1.2, 1.2)
#axs[0].set_ylim(-1.2, 1.2)
#axs[1].set_xlim(-1.2, 1.2)
#axs[1].set_ylim(-1.2, 1.2)

print (centers[:, dim1] - centers_init[:, dim1])
print (centers[:, dim2] - centers_init[:, dim2])
print (stddevs[:, dim1] - stddevs_init[:, dim1])
print (stddevs[:, dim2] - stddevs_init[:, dim2])
In [None]: